#include "statsxx/machine_learning/RBM.hpp"

// STL
#include <cmath> // std::sqrt()

// jScience
#include "jScience/linalg.hpp" // Vector<>
#include "jrandnum.hpp"        // rand_num_normal_Mersenne_twister()

// stats++
#include "statsxx/distribution.hpp" // distribution::binomial()
#include "statsxx/machine_learning/activation_functions.hpp" // activation_function::Logistic, ::ReLU, ::softplus


// using namespace machine_learning;


//
// DESC: stochastically samples the activations of a layer of RBM units.
//
// NOTE: the (noisy) rectified linear unit, as implemented here for the ReLU and softplus cases, is discussed in:
//     V. Nair and G. E. Hinton, Proceedings of the 27th International Conference on Machine Learning (2010)
//
// Draws one stochastic sample per unit given the units' total inputs x and
// their activation probabilities/means px.
//
// @param x    total input to each unit (weighted sum plus bias)
// @param px   activation probability (binary case) / mean activation (continuous case)
// @param type unit type: 0 == binary; 1 == continuous; 2 == ReLU; 3 == softplus
// @return     the sampled activations; an empty Vector for an unrecognized type
//
// NOTE: `inline` was removed from this definition — an inline member function
// defined only in a .cpp file has no definition visible to other translation
// units, causing undefined-reference link errors for external callers.
Vector<double> machine_learning::RBM::sample(
                                             const Vector<double> &x,
                                             const Vector<double> &px,
                                             const int             type // 0 == binary; 1 == continuous; 2 == ReLU; 3 == softplus
                                             )
{
    Vector<double> y;
    
    // TODO: I think that the activation member functions could be made static ... then they could be called as commented out in this->calculate_p()
    activation_function::Logistic logistic;
    activation_function::ReLU     ReLU;
    activation_function::softplus softplus;
    
    switch(type)
    {
        case 0:
            // binary unit: Bernoulli draw with per-unit success probability px(j)
            y = distribution::binomial<double>(1, px);
            break;
        case 1:
            // continuous (linear) unit: mean px plus zero-mean Gaussian noise
            y = px;
            // NOTE: use the container's own size type for the index to avoid a
            // signed/unsigned comparison (`auto j = 0` deduced int)
            for(decltype(y.size()) j = 0; j < y.size(); ++j)
            {
                // TODO: is the variance of the following best to keep always at 1?
                y(j) += rand_num_normal_Mersenne_twister(0.,1.);
            }
            break;
        case 2: // noisy rectified linear unit (NReLU)
        case 3: // noisy softplus unit
            y = x;
            for(decltype(y.size()) j = 0; j < y.size(); ++j)
            {
                // NOTE: in Hinton's "A Practical Guide to Training Restricted Boltzmann Machines", the NReLU variance is given as one, but in the original reference (and in fact in the guide), the variance is sigma(x) so that units that are firmly off do not create noise and the noise does not become large when x is large
                const double sigmaxj = logistic.f(x(j));
                y(j) += rand_num_normal_Mersenne_twister(0.,std::sqrt(sigmaxj));
            }
            if(type == 2)
            {
                y = ReLU.f(y);
            }
            else // if(type == 3)
            {
                y = softplus.f(y);
            }
            break;
        default:
            // unrecognized type: fall through and return the (empty) default-constructed Vector
            break;
    }
    
    return y;
}
